{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NeMo Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import NeMo and its ASR collection\n",
    "import nemo\n",
    "import nemo.collections.asr as nemo_asr\n",
    "\n",
    "# Build the NeuralModuleFactory used later in this notebook (see `nf.train`).\n",
    "# Try the default placement first; if construction fails, fall back to an\n",
    "# explicit CPU placement.\n",
    "try:\n",
    "    nf = nemo.core.NeuralModuleFactory()\n",
    "except Exception:  # not a bare `except:` so Ctrl-C / SystemExit still propagate\n",
    "    print(\"GPU was not detected. Running on CPU\")\n",
    "    nf = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.CPU)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NeMoModel instantiation without pre-trained weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A *NeMoModel* is a kind of NeuralModule that contains other neural modules inside it.\n",
    "The mode of those inner modules and the topology of their connections can\n",
    "depend on the mode in which the NeMoModel is used (training or evaluation)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Because NeMoModel is a NeuralModule, regular constructor-based initialization applies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to your local NeMo checkout; replace the placeholder before running.\n",
    "NEMO_GIT_ROOT = '<REPLACE_WITH_PATH_TO_NEMO_GIT_ROOT>'\n",
    "\n",
    "# Read the model definition from its YAML config with ruamel's \"safe\" loader.\n",
    "from ruamel.yaml import YAML\n",
    "\n",
    "yaml = YAML(typ=\"safe\")\n",
    "config_path = f\"{NEMO_GIT_ROOT}/examples/asr/configs/jasper_an4.yaml\"\n",
    "with open(config_path) as config_stream:\n",
    "    model_definition = yaml.load(config_stream)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each constructor argument takes its own section of the loaded YAML definition.\n",
    "quartznet_model1 = nemo_asr.models.QuartzNet(\n",
    "    preprocessor_params=model_definition['AudioToMelSpectrogramPreprocessor'],\n",
    "    encoder_params=model_definition['JasperEncoder'],\n",
    "    decoder_params=model_definition['JasperDecoderForCTC'],\n",
    ")\n",
    "num_weights = quartznet_model1.num_weights\n",
    "print(f\"Created QuartzNet model with {num_weights} weights\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Because NeMoModel is a NeuralModule, regular config import/export work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the model's configuration to a YAML file in the working directory.\n",
    "exported_config = \"qn1.yaml\"\n",
    "quartznet_model1.export_to_config(exported_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a second model instance from the configuration file \"qn1.yaml\".\n",
    "quartznet_cls = nemo_asr.models.QuartzNet\n",
    "quartznet_model2 = quartznet_cls.import_from_config(config_file=\"qn1.yaml\")\n",
    "num_weights = quartznet_model2.num_weights\n",
    "print(f\"Created QuartzNet model with {num_weights} weights\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NeMoModel instantiation with pre-trained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the name of every pre-trained model available from NGC\n",
    "pretrained_models = nemo_asr.models.ASRConvCTCModel.list_pretrained_models()\n",
    "for model_entry in pretrained_models:\n",
    "    print(model_entry.pretrained_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the named checkpoint from NGC and instantiate the model together with its weights\n",
    "pretrained_name = \"QuartzNet15x5-En\"\n",
    "quartznet_model3 = nemo_asr.models.QuartzNet.from_pretrained(model_info=pretrained_name)\n",
    "num_weights = quartznet_model3.num_weights\n",
    "print(f\"Created QuartzNet model with {num_weights} weights\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export model to \".nemo\" format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export to \".nemo\" file - all params, structure and weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the pre-trained model to a \".nemo\" file in the working directory.\n",
    "nemo_file = 'quartznet.nemo'\n",
    "quartznet_model3.save_to(nemo_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here from_pretrained is pointed at the local \".nemo\" file rather than an NGC model name.\n",
    "local_checkpoint = 'quartznet.nemo'\n",
    "quartznet_model4 = nemo_asr.models.QuartzNet.from_pretrained(model_info=local_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A \".nemo\" file is just an archive with all of the model's details and weights.\n",
    "# Copy it (instead of moving it) so 'quartznet.nemo' still exists if the cell that\n",
    "# loads it is re-run, then unpack the copy to inspect its contents.\n",
    "! cp quartznet.nemo quartznet.tar.gz\n",
    "! tar -xvf quartznet.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export to \".nemo\" file - for deployment with NVIDIA Jarvis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same export call as before, now with optimize_for_deployment switched on.\n",
    "jarvis_file = 'quartznet_for_Jarvis.nemo'\n",
    "quartznet_model3.save_to(jarvis_file, optimize_for_deployment=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A \".nemo\" file optimized for deployment will only contain eval structure and .onnx files.\n",
    "# Copy it (instead of moving it) so the exported 'quartznet_for_Jarvis.nemo' artifact is\n",
    "# kept, then unpack the copy to inspect its contents.\n",
    "! cp quartznet_for_Jarvis.nemo quartznet_for_Jarvis.nemo.tar.gz\n",
    "! tar -xvf quartznet_for_Jarvis.nemo.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NeMoModels can be used just as any other Neural Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the AN4 manifests; replace the placeholder with your own path.\n",
    "# (A placeholder is used instead of a machine-specific absolute path.)\n",
    "AN4_DATA_ROOT = '<REPLACE_WITH_PATH_TO_AN4_DATASET>'\n",
    "train_manifest = f\"{AN4_DATA_ROOT}/an4_train.json\"\n",
    "val_manifest = f\"{AN4_DATA_ROOT}/an4_val.json\"\n",
    "\n",
    "# The label set is read from the YAML model definition loaded earlier.\n",
    "labels = model_definition['labels']\n",
    "\n",
    "# One data layer per manifest (training and validation).\n",
    "data_layer = nemo_asr.AudioToTextDataLayer(manifest_filepath=train_manifest, labels=labels, batch_size=16)\n",
    "data_layerE = nemo_asr.AudioToTextDataLayer(manifest_filepath=val_manifest, labels=labels, batch_size=16)\n",
    "\n",
    "# CTC loss sized to the label set, plus a greedy CTC decoder.\n",
    "ctc_loss = nemo_asr.CTCLossNM(num_classes=len(labels))\n",
    "greedy_decoder = nemo_asr.GreedyCTCDecoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wire the modules together: data layer -> model -> greedy decoder and CTC loss.\n",
    "(audio_signal, audio_signal_len,\n",
    " transcript, transcript_len) = data_layer()\n",
    "log_probs, encoded_len = quartznet_model4(\n",
    "    input_signal=audio_signal,\n",
    "    length=audio_signal_len,\n",
    ")\n",
    "predictions = greedy_decoder(log_probs=log_probs)\n",
    "loss = ctc_loss(\n",
    "    log_probs=log_probs,\n",
    "    targets=transcript,\n",
    "    input_length=encoded_len,\n",
    "    target_length=transcript_len,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from nemo.collections.asr.helpers import monitor_asr_train_progress\n",
    "\n",
    "# START TRAINING\n",
    "# The callback receives the loss followed by the tensors in `tensors_to_evaluate`\n",
    "# and hands them to monitor_asr_train_progress (with `labels` bound via partial).\n",
    "tensors_to_evaluate = [predictions, transcript, transcript_len]\n",
    "train_callback = nemo.core.SimpleLossLoggerCallback(\n",
    "    tensors=[loss, *tensors_to_evaluate],\n",
    "    print_func=partial(monitor_asr_train_progress, labels=labels),\n",
    ")\n",
    "\n",
    "optimization_params = {\"num_epochs\": 1, \"lr\": 1e-2, \"weight_decay\": 1e-3}\n",
    "nf.train(\n",
    "    tensors_to_optimize=[loss],\n",
    "    callbacks=[train_callback],\n",
    "    optimizer=\"novograd\",\n",
    "    optimization_params=optimization_params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}